import matplotlib.pyplot as plt
import os
import matplotlib
# Embed fonts as Type 42 (TrueType) instead of Type 3 in PDF and PostScript
# output, so text in the saved figures stays editable/selectable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

def show_wave(x, y, x_label, y_label, file_name):
    plt.figure(figsize=(6, 4))
    plt.plot(x, y, linewidth=2.25, color="blue", marker='o', markersize=8)

    plt.yticks(fontproperties='Times New Roman', fontsize=24)
    plt.xticks(fontproperties='Times New Roman', fontsize=24)
    plt.xlabel(x_label, fontdict={'family': 'Times New Roman', 'size': 24})
    plt.ylabel(y_label, fontdict={'family': 'Times New Roman', 'size': 24})

    plt.subplots_adjust(bottom=0.15, left=0.15, right=0.95, top=0.95)
    plt.xticks([0, .5, 1, 1.5, 2], ['0.0',  '0.5', '1.0','1.5', '2.0'])
    # plt.ticklabel_format(axis="y", style="sci", scilimits=(2, 5))
    if y_label == 'AUC':
        plt.yticks([0.0, 0.04, 0.08, 0.12, 0.16], ['0', '0.4', '0.8', '1.2', '1.6'],
                   fontproperties='Times New Roman', fontsize=24)
        plt.text(-0.1, 0.1721, "1e-2", fontsize=20,  fontfamily='Times New Roman')
    if y_label == 'MSE':
        plt.yticks([0.077, 0.078, 0.079, 0.08, 0.081], ['7.7', '7.8', '7.9', '8.0', '8.1'],
                   fontproperties='Times New Roman', fontsize=24)
        plt.text(-0.1, 0.0811, "1e-2", fontsize=20,  fontfamily='Times New Roman')
    # plt.set_visible(False)
    plt.yticks(fontproperties='Times New Roman', fontsize=24)
    plt.xticks(fontproperties='Times New Roman', fontsize=24)
    plt.tight_layout()
    # plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))

    plt.savefig(os.path.join('fig_gen', file_name + ".pdf"), format='pdf', dpi=600)
    # plt.show()


if __name__ == '__main__':
    # TODO: switch to the guri results; the numbers below are from the GRU model.
    gammas = [0, 0.125, 0.25, 1, 2]
    # (metric values per gamma, y-axis label, output file name)
    runs = [
        ([0.0767, 0.0788, 0.0797, 0.0805, 0.0805], 'MSE', 'lambda_p'),
        ([0.0123, 0.1024, 0.1388, 0.1605, 0.1631], 'AUC', 'lambda_d'),
    ]
    for values, metric, out_name in runs:
        show_wave(gammas, values, 'γ', metric, out_name)
